Add use_linear option to replace Conv3d tokenizer with Linear layers #48
nicholasmalaya wants to merge 1 commit into ORNL:main from
Conversation
When kernel_size == stride (non-overlapping patches), Conv3d is mathematically equivalent to reshape + nn.Linear. This avoids the im2col/col2im overhead and replaces MIOpen's implicit GEMM backward-weight path with standard rocBLAS matmul backward.

Profiling on MI355X (gfx950) shows the backward-weight GEMM (kernel_batched_gemm_xdlops_bwd_weight) consumed 79.3% of compute time. With use_linear=True, this kernel is eliminated entirely, yielding a 2.87x end-to-end training speedup with identical loss convergence.

Enabled via config: use_linear: !!bool True (default False, fully backward compatible).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
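A minimal sketch of the equivalence the PR relies on. Tensor sizes and the patch size `P` here are hypothetical, chosen only for illustration; the point is that when `kernel_size == stride`, a `Conv3d` tokenizer and a reshape-into-patches followed by `nn.Linear` with the flattened conv weights produce the same token embeddings:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration; patch size P = kernel_size = stride.
B, C, D, H, W, E, P = 2, 3, 8, 8, 8, 16, 4

conv = nn.Conv3d(C, E, kernel_size=P, stride=P)

# Equivalent Linear layer: the conv kernel flattened per patch,
# conv.weight has shape (E, C, P, P, P) -> (E, C*P*P*P).
linear = nn.Linear(C * P * P * P, E)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(E, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(B, C, D, H, W)

# Conv3d path: (B, E, D/P, H/P, W/P) -> (B, N, E) tokens.
out_conv = conv(x).flatten(2).transpose(1, 2)

# Linear path: carve non-overlapping patches, flatten each, then matmul.
patches = (
    x.reshape(B, C, D // P, P, H // P, P, W // P, P)
     .permute(0, 2, 4, 6, 1, 3, 5, 7)   # (B, d, h, w, C, P, P, P)
     .reshape(B, -1, C * P * P * P)     # (B, N, C*P^3)
)
out_linear = linear(patches)

print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True
```

Forward and backward through the Linear path are plain batched matmuls (rocBLAS), which is why the MIOpen backward-weight convolution kernel disappears from the profile.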
pzhanggit left a comment
@nicholasmalaya thank you very much for the optimization, Nick!
@TsChala the PR looks good to me. Could you do a test run on Frontier when it's back from maintenance? We should extend the changes to other models for better performance as well. Thanks
Thanks for the edits @nicholasmalaya! @pzhanggit I ran some tests on the JHUTDB dataset today. Using the Turbulence Transformer I see around a 2x speed-up! This is only from the … For the …